import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
import random
import gym
import math
from torch.utils.tensorboard import SummaryWriter
from collections import deque, namedtuple
import time

import warnings
warnings.filterwarnings("ignore")

import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

from torch.autograd import Variable
import pandas as pd
from lifelines import CoxPHFitter
from lifelines import AalenAdditiveFitter
from lifelines import KaplanMeierFitter
import lifelines
from scipy.optimize import minimize
from sklearn import linear_model
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from lifelines import CoxPHFitter
import numpy as np
import pandas as pd

###############################################################################################################
# Unified survival functions: keep your original Aalen / Cox / KM forms and return shapes/usage
###############################################################################################################

def Survival_T_Aalen(time, Y, X, A, Delta):  # Aalen additive hazards (matches your first section)
    """Fit an Aalen additive hazards model on sign-adjusted covariates and predict survival.

    Args:
        time: Scalar time point, or 1-D array of time points, at which survival is evaluated.
        Y: Observed durations, one per subject (flattened before use).
        X: (n, p) covariate matrix.
        A: Length-n array of binary actions in {0, 1}.
        Delta: Length-n event indicators (1 = event observed), flattened before use.

    Returns:
        ndarray with a single row. It holds the predicted survival table (times as rows, subjects
        as columns) flattened row-major. For a scalar ``time`` this is shape (1, n), one value per
        subject.
    """
    sign = 2 * np.asarray(A).reshape(-1, 1) - 1
    X_adjusted = np.multiply(X, sign)  # (2A - 1) * X
    # One column per covariate; for p == 2 this is exactly ['X1*A', 'X2*A'].
    columns = [f'X{j + 1}*A' for j in range(X_adjusted.shape[1])]

    data = pd.DataFrame(X_adjusted, columns=columns)
    data['Y'] = np.asarray(Y).flatten()
    data['Delta'] = np.asarray(Delta).flatten()
    # Predict on an independent copy taken before fitting, so fitting cannot affect prediction input.
    pre_data = data.copy()

    aaf = AalenAdditiveFitter()
    aaf.fit(data, duration_col='Y', event_col='Delta')
    st_estimate = aaf.predict_survival_function(pre_data, times=time)
    return st_estimate.values.reshape(1, -1)


def Survival_T_Cox(time, Y, X, A, Delta):  # CoxPH (matches your second section)
    """Fit a Cox proportional-hazards model on sign-adjusted covariates and predict survival.

    Args:
        time: Scalar time point, or 1-D array of time points, at which survival is evaluated.
        Y: Observed durations, one per subject (flattened before use).
        X: (n, p) covariate matrix.
        A: Length-n array of binary actions in {0, 1}.
        Delta: Length-n event indicators (1 = event observed), flattened before use.

    Returns:
        ndarray of shape (len(time), n): the predicted survival values (times as rows, subjects as
        columns). For a scalar ``time`` this is (1, n).
    """
    sign = 2 * np.asarray(A).reshape(-1, 1) - 1
    X_adjusted = np.multiply(X, sign)  # (2A - 1) * X
    # One column per covariate; for p == 2 this is exactly ['X1*A', 'X2*A'].
    columns = [f'X{j + 1}*A' for j in range(X_adjusted.shape[1])]

    data = pd.DataFrame(X_adjusted, columns=columns)
    data['Y'] = np.asarray(Y).flatten()
    data['Delta'] = np.asarray(Delta).flatten()
    # Predict on an independent copy taken before fitting, so fitting cannot affect prediction input.
    pre_data = data.copy()

    cph = CoxPHFitter()
    cph.fit(data, duration_col='Y', event_col='Delta')
    st_estimate = cph.predict_survival_function(pre_data, times=time)
    return st_estimate.values


def Survival_C(time, Y, X, A, Delta):  # KM (matches your third section)
    """Kaplan-Meier estimate of the censoring survival function, evaluated at ``time``.

    The "event" given to the Kaplan-Meier fitter is censoring (``1 - Delta``). The result therefore
    estimates the probability of remaining uncensored.

    Args:
        time: Scalar or 1-D array of times at which the estimate is evaluated.
        Y: Observed durations (flattened before use).
        X: Covariates; accepted for signature compatibility with the other survival functions, unused.
        A: Actions; accepted for signature compatibility, unused.
        Delta: Event indicators (1 = event observed); flattened before use.

    Returns:
        1-D ndarray of survival-of-censoring estimates, one per entry of ``time``.
    """
    durations = np.asarray(Y).flatten()
    censored = 1 - np.asarray(Delta).flatten()
    kmf = KaplanMeierFitter()
    kmf.fit(durations=durations, event_observed=censored)
    sc_estimate = kmf.survival_function_at_times(time)
    return sc_estimate.values

###############################################################################################################
# Shared network: keep your original structure, but provide two forward variants (negative/positive softplus) for the three sections
###############################################################################################################

class QR_DQN(nn.Module):
    """One-hidden-layer MLP producing a softplus score per action.

    With ``neg_softplus=True`` the scores are negated, so every output is <= 0. Otherwise every
    output is >= 0.
    """

    def __init__(self, state_size, action_size, layer_size, seed, neg_softplus=True):
        """Build the network.

        Args:
            state_size: Observation shape; ``state_size[1]`` is used as the input feature count.
            action_size: Number of outputs (one per action).
            layer_size: Width of the hidden layer.
            seed: Passed to ``torch.manual_seed`` before the layers are created.
            neg_softplus: Whether ``forward`` negates the softplus output.
        """
        super().__init__()
        # Seeding before constructing the layers makes their initial weights reproducible.
        self.seed = torch.manual_seed(seed)
        self.input_shape = state_size[1]
        self.action_size = action_size
        self.neg_softplus = neg_softplus
        self.head_1 = nn.Linear(self.input_shape, layer_size)
        self.ff_2 = nn.Linear(layer_size, action_size)

    def forward(self, x):
        """Map features (last dimension = ``input_shape``) to per-action scores."""
        hidden = F.relu(self.head_1(x))
        scores = F.softplus(self.ff_2(hidden))
        return -scores if self.neg_softplus else scores

###############################################################################################################
# Shared utility functions: same format as your original, reused here
###############################################################################################################

def normalization(data):
    """Min-max scale ``data`` to [0, 1] using its global minimum and maximum.

    Args:
        data: Array-like of numbers (any shape); must be non-empty.

    Returns:
        ndarray of the same shape. When every value is equal the range is 0, so an all-zeros
        float array is returned instead of the NaNs that 0/0 would produce.

    Raises:
        ValueError: If ``data`` is empty (from ``np.min``).
    """
    arr = np.asarray(data)
    lo = np.min(arr)
    span = np.max(arr) - lo
    if span == 0:
        return np.zeros(arr.shape)
    return (arr - lo) / span


###############################################################################################################
# Generic environment: switch reward calculation by 'mode' (faithfully reproduces your three code paths)
###############################################################################################################

class CustomEnv(gym.Env):
    """Vectorized environment: ``num_agents`` subjects, each with a 2-D state and a binary action.

    Rewards come from a survival model chosen by ``mode``:
    - 'aalen': log survival from an Aalen additive model.
    - 'cox': log survival from a Cox model.
    - Any other mode (e.g. 'km'): a censoring-weighted duration, using a Kaplan-Meier estimate of
      the censoring distribution.
    """
    def __init__(self, max_steps, num_agents, mode, time_point):
        """Create the environment.

        Args:
            max_steps: ``step`` reports ``done`` once the internal step counter reaches this value.
            num_agents: Number of subjects simulated in parallel.
            mode: 'aalen' or 'cox' select those reward models; any other value uses the KM reward.
            time_point: Time at which Aalen/Cox survival is evaluated (not used in the KM branch).
        """
        super(CustomEnv, self).__init__()

        self.action_space = gym.spaces.Discrete(2)
        self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(num_agents, 2), dtype=np.float32)
        self.num_agents = num_agents
        self.time_point = time_point

        self.max_steps = max_steps
        self.current_step = 0
        self.mode = mode  # 'aalen', 'cox', 'km'

        # Only stored here; it is applied inside reset().
        self.seed_value = None

    def seed(self, seed=None):
        """Store ``seed`` so that every later ``reset`` reseeds NumPy's global RNG with it."""
        self.seed_value = seed

    def reset(self):
        """Start a new episode and return ``(state, observe)``.

        With a stored seed, NumPy's *global* RNG is reseeded on every call. Every episode then starts
        from the same random (num_agents, 2) standard-normal state and the same censoring times, and
        later global NumPy draws restart from the same sequence. Without a seed the state starts at
        zeros. Censoring times are drawn from Uniform(0, 70). ``observe`` is all ones, i.e. every
        agent is under observation.
        """
        if self.seed_value is not None:
            np.random.seed(self.seed_value)
            self.state = np.random.randn(self.num_agents, 2)
            self.censor = np.random.uniform(low=0, high=10*7, size=(self.num_agents))
        else:
            self.censor = np.random.uniform(low=0, high=10*7, size=(self.num_agents))
            self.state = np.zeros((self.num_agents, 2))
        self.current_step = 0
        self.observe = np.ones(self.num_agents)

        return self.state, self.observe

    def step(self, action):
        """Advance every agent one step.

        Args:
            action: (num_agents, 1) array-like of 0/1 actions; it is converted with ``np.array``.

        Returns:
            Tuple ``(next_state, reward, next_observe, Delta, survival_time1, survival_time2, done)``.
            ``reward`` and ``Delta`` only cover agents that were observed before this step.
            ``done`` is evaluated before the counter is incremented, so it is first True on the call
            made while ``current_step == max_steps``.
        """
        reward, Delta, survival_time1, survival_time2 = self.reward_function(action, self.state, self.observe)
        state_next = self.state_function(self.state, action)
        next_observe = self.observe_function(self.observe, Delta)
        self.state = state_next
        self.observe = next_observe
        done = self.current_step >= self.max_steps
        self.current_step += 1

        return state_next, reward, next_observe, Delta, survival_time1, survival_time2, done

    def observe_function(self, observe, Delta):
        """Return the updated observation mask.

        ``Delta`` has one entry per currently observed agent, in index order. ``Delta == 0`` means
        the censoring time came first (``T > censor`` in ``reward_function``). Those agents are
        dropped from the observed set.
        """
        next_observe = observe.copy()
        indices_matrix = np.nonzero(next_observe)[0]
        next_observe[indices_matrix[Delta == 0]] = 0
        return next_observe

    def state_function(self, state, action):
        """Return the next state for every agent.

        For action a in {0, 1}, each row is right-multiplied by diag(0.75*(2a-1), 0.75*(1-2a)).
        Gaussian noise with standard deviation 0.25 is then added to every entry.
        """
        A = np.array(action)
        next_state = np.zeros(state.shape)
        for i in range(state.shape[0]):
            beta = np.array([[3 / 4 * (2 * A[i][0] - 1), 0], [0, 3 / 4 * (1 - 2 * A[i][0])]])
            next_state[i] = np.dot(state[i], beta)
        next_state = next_state + np.random.normal(0, 0.25, size=state.shape)
        return next_state

    def reward_function(self, action, state, observe):
        """Simulate event/censoring times and compute rewards for the observed agents.

        Returns:
            ``(reward, Delta, survival_time1, survival_time2)``.
            - ``reward`` and ``Delta`` are restricted to agents with ``observe == 1``.
            - 'aalen'/'cox' rewards are ``log(S(time_point) + 1e-5)``, in the 2-D layout returned by
              Survival_T_Aalen / Survival_T_Cox.
            - The KM reward is a 1-D array.
            - ``survival_time1`` and ``survival_time2`` are population means over all agents.
        """
        action = np.array(action)
        action = action[:, 0]
        # Intercept 0, coefficients -2 and 1 on the two state columns.
        beta_1 = np.array([0, -2, 1])
        state_1 = np.insert(state, 0, 1, axis=1)

        # The action flips the sign of the linear predictor.
        linear_predictor = (2 * action - 1) * np.dot(state_1, beta_1)
        baseline_hazard = 1
        param = 1 / (np.exp(linear_predictor) * baseline_hazard)
        # Gamma(shape=1/param, scale=param) has mean shape*scale = 1 and variance param.
        # NOTE(review): the mean does not depend on the linear predictor — confirm this is intended.
        T = np.random.gamma(shape=1 / param, scale=param)

        # Both summaries are taken over all agents, not only the observed ones.
        survival_time1 = np.mean(param)
        T = np.clip(T, None, 7)
        survival_time2 = np.mean(T)

        Y = np.minimum(T, self.censor)
        Delta = (T <= self.censor).astype(int)  # 1 = event before censoring
        # Remaining censoring time shrinks for every agent, observed or not.
        self.censor = self.censor - Y
        indices = np.where(observe == 1)[0]
        Y = Y[indices]
        Delta = Delta[indices]
        state = state[indices]
        action = action[indices]
        state_1 = state_1[indices]  # NOTE(review): not used below.

        if self.mode == 'aalen':
            time = self.time_point
            reward = Survival_T_Aalen(time, Y, np.array(state), np.array(action), Delta)
            reward = np.log(reward + 1e-5)
            return reward, Delta, survival_time1, survival_time2

        if self.mode == 'cox':
            time = self.time_point
            reward = Survival_T_Cox(time, Y, np.array(state), np.array(action), Delta)
            reward = np.log(reward + 1e-5)
            return reward, Delta, survival_time1, survival_time2

        # KM (third section)
        # S: Kaplan-Meier survival of the censoring distribution, evaluated at each observed Y.
        S = Survival_C(Y, Y, np.array(state), np.array(action), Delta)
        reward = np.zeros_like(Y)
        # Events get Y / S_C(Y); censored agents keep reward 0.
        non_zero_indices = Delta != 0
        reward[non_zero_indices] = (Y[non_zero_indices] * Delta[non_zero_indices]) / S[non_zero_indices]
        return reward, Delta, survival_time1, survival_time2

###############################################################################################################
# ReplayBuffer: keep your implementation, while making reward shapes compatible (two shapes)
###############################################################################################################

class ReplayBuffer:
    """Fixed-size buffer to store experience tuples.

    Each experience stores population-level arrays: ``state``, ``action``, the ``observe`` mask that
    was active when the reward was computed, the ``reward`` (entries only for the observed agents),
    ``next_state`` and ``done``.
    """
    def __init__(self, buffer_size, batch_size, device, seed, gamma, n_step=1):
        """Create the buffer.

        Args:
            buffer_size: Maximum number of stored experiences (oldest are dropped first).
            batch_size: Number of experiences returned by ``sample``.
            device: torch device that sampled tensors are moved to.
            seed: Seed for Python's ``random`` module, which ``sample`` uses.
            gamma: Discount factor for the multistep return.
            n_step: Number of transitions aggregated into one stored experience.
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "observe", "reward", "next_state", "done"])
        self.seed = random.seed(seed)
        self.gamma = gamma
        self.n_step = n_step
        self.n_step_buffer = deque(maxlen=self.n_step)

    def add(self, state, action, observe, reward, next_state, done):
        """Queue one transition; once ``n_step`` transitions are queued, store their aggregate."""
        self.n_step_buffer.append((state, action, observe, reward, next_state, done))
        if len(self.n_step_buffer) == self.n_step:
            state, action, observe, reward, next_state, done = self.calc_multistep_return()
            e = self.experience(state, action, observe, reward, next_state, done)
            self.memory.append(e)

    def calc_multistep_return(self):
        """Return the first (state, action, observe), the discounted reward sum and the last (next_state, done).

        NOTE(review): for ``n_step > 1`` the per-step reward arrays can differ in length (the observed
        set shrinks), so summing them may fail to broadcast.
        """
        Return = 0
        for idx in range(self.n_step):
            Return += self.gamma ** idx * self.n_step_buffer[idx][3]
        return self.n_step_buffer[0][0], self.n_step_buffer[0][1], self.n_step_buffer[0][2], Return, self.n_step_buffer[-1][4], self.n_step_buffer[-1][5]

    @staticmethod
    def _reward_row(reward):
        """Return one transition's per-observed-agent rewards as a 1-D array.

        Aalen/Cox rewards are 2-D and their first row holds the values; KM rewards are already 1-D.
        """
        reward = np.asarray(reward)
        return reward[0] if reward.ndim > 1 else reward

    def sample(self):
        """Randomly sample ``batch_size`` experiences and return them as tensors on ``self.device``.

        Returns:
            ``(states, actions, observes, rewards, next_states, dones)``. ``rewards`` is expanded to
            full population width: observed agents get their reward in index order, and every other
            entry is 0.
        """
        experiences = [e for e in random.sample(self.memory, k=self.batch_size) if e is not None]

        states = torch.from_numpy(np.stack([e.state for e in experiences])).float().to(self.device)
        observes = torch.from_numpy(np.stack([e.observe for e in experiences])).long().to(self.device)
        actions = torch.from_numpy(np.stack([e.action for e in experiences])).long().to(self.device)

        rewards = []
        for e in experiences:
            observe_indices = np.where(e.observe == 1)[0]
            reward = np.zeros_like(e.observe, dtype=np.float32)
            reward[observe_indices] = self._reward_row(e.reward)[:len(observe_indices)]
            rewards.append(reward)
        # Use the buffer's own device; the original read an undefined global `device` here.
        rewards = torch.from_numpy(np.stack(rewards)).float().to(self.device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences])).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(self.device)

        return (states, actions, observes, rewards, next_states, dones)

    def __len__(self):
        """Return the number of stored experiences."""
        return len(self.memory)

###############################################################################################################
# DQN_Agent: keep your implementation, but allow positive/negative softplus when constructing the networks
###############################################################################################################

class DQN_Agent():
    """Interacts with and learns from the environment.

    Wraps a local and a target ``QR_DQN`` network, an Adam optimizer on the local network and a
    ``ReplayBuffer``. A state of shape (n, features) yields one action per row.
    """
    def __init__(self,
                 state_size,
                 action_size,
                 Network,
                 layer_size,
                 n_step,
                 BATCH_SIZE,
                 BUFFER_SIZE,
                 LR,
                 TAU,
                 GAMMA,
                 UPDATE_EVERY,
                 device,
                 seed,
                 neg_softplus=True):
        """Build the local/target networks, the optimizer, the replay buffer and the step counters.

        Args:
            state_size: Observation shape; ``state_size[1]`` is the per-agent feature count.
            action_size: Number of discrete actions.
            Network: Accepted for interface compatibility; not used.
            layer_size: Hidden width of each QR_DQN.
            n_step: Horizon of the multistep return stored in the buffer.
            BATCH_SIZE: Replay batch size.
            BUFFER_SIZE: Replay capacity.
            LR: Adam learning rate.
            TAU: Polyak factor for soft target updates.
            GAMMA: Discount factor.
            UPDATE_EVERY: Learn once every this many ``step`` calls.
            device: torch device for networks and sampled batches.
            seed: Seed for Python's ``random`` module and for network initialisation.
            neg_softplus: Forwarded to QR_DQN (True -> non-positive outputs).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.seed = random.seed(seed)
        self.device = device
        self.TAU = TAU
        self.GAMMA = GAMMA
        self.UPDATE_EVERY = UPDATE_EVERY
        self.BATCH_SIZE = BATCH_SIZE
        self.Q_updates = 0
        self.n_step = n_step

        self.qnetwork_local = QR_DQN(state_size, action_size, layer_size, seed, neg_softplus=neg_softplus).to(self.device)
        self.qnetwork_target = QR_DQN(state_size, action_size, layer_size, seed, neg_softplus=neg_softplus).to(self.device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=LR)

        self.memory = ReplayBuffer(BUFFER_SIZE, BATCH_SIZE, self.device, seed, self.GAMMA, n_step)

        self.t_step = 0

    def step(self, state, action, observe, reward, next_state, done, writer):
        """Store the transition and periodically run one learning update.

        An update runs every ``UPDATE_EVERY`` calls, provided the buffer holds more than
        ``BATCH_SIZE`` experiences. Its loss is logged to ``writer`` as "Q_loss".
        """
        self.memory.add(state, action, observe, reward, next_state, done)
        self.t_step = (self.t_step + 1) % self.UPDATE_EVERY
        if self.t_step == 0:
            if len(self.memory) > self.BATCH_SIZE:
                experiences = self.memory.sample()
                loss = self.learn(experiences)
                self.Q_updates += 1
                writer.add_scalar("Q_loss", loss, self.Q_updates)

    def act(self, state, eps=0.):
        """Choose one action per agent with an epsilon-greedy rule.

        Args:
            state: (n, features) array-like.
            eps: Probability-like threshold; greedy when ``random.random() > eps``.

        Returns:
            ``(actions, values)``.
            - ``actions`` is an (n, 1) integer tensor on the CPU.
            - ``values`` are the chosen actions' scores on ``self.device``: shape (n,) in the greedy
              branch, (n, 1) in the random branch.
        """
        state = np.array(state)
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)

        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local.forward(state)
        self.qnetwork_local.train()
        action_values = action_values.squeeze(0)
        if random.random() > eps:
            max_value, action = torch.max(action_values, dim=1)
            # Return actions on the CPU. env.step converts them with np.array(), which raises on
            # CUDA tensors. On a CPU device .cpu() is a no-op.
            action = action.unsqueeze(1).cpu()
            return action, max_value
        else:
            action = torch.randint(low=0, high=action_values.shape[1], size=(action_values.shape[0], 1))
            max_value = action_values.gather(1, action.to(self.device))
            return action, max_value

    def learn(self, experiences):
        """Run one gradient step on a sampled batch and return the loss as a NumPy scalar.

        Only observed (agent, sample) entries are used. The target is
        ``r + GAMMA * max_a Q_target(s', a)``, and the loss is the MSE against
        ``Q_local(s, a)``. After the optimizer step the target network is soft-updated.

        NOTE(review): ``dones`` is unpacked but not used, so targets always bootstrap from
        ``s'`` — confirm this is intended.
        """
        self.optimizer.zero_grad()
        states, actions, observes, rewards, next_states, dones = experiences
        mask = observes.bool()
        states = states[mask]
        next_states = next_states[mask]
        actions = actions[mask]
        rewards = rewards[mask]

        Q_targets_next = self.qnetwork_target(next_states).detach().cpu()
        action_indx = torch.argmax(Q_targets_next, dim=1)
        Q_targets_next = Q_targets_next.gather(1, action_indx.unsqueeze(1))

        Q_targets = rewards.unsqueeze(1) + self.GAMMA * Q_targets_next.to(self.device)

        Q_expected = self.qnetwork_local(states).gather(1, actions)
        loss = F.mse_loss(Q_targets, Q_expected)
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)
        return loss.detach().cpu().numpy()

    def soft_update(self, local_model, target_model):
        """Polyak-average: target <- TAU * local + (1 - TAU) * target, parameter by parameter."""
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(self.TAU * local_param.data + (1.0 - self.TAU) * target_param.data)

###############################################################################################################
# Training loop: keep your 'run' logic, only adapt reward extraction for shape differences (no behavioral change)
###############################################################################################################

def _as_reward_row(reward):
    """Return one step's per-observed-agent rewards as a 1-D array.

    Aalen/Cox rewards are 2-D and their first row holds the values; KM rewards are already 1-D.
    """
    reward = np.asarray(reward)
    return reward[0] if reward.ndim > 1 else reward


def run(episodes=1000, eps_fixed=False, eps_frames=1e6, min_eps=0.01, num_agents=1000):
    """Train the global ``agent`` on the global ``env`` and collect per-step and per-episode metrics.

    Reads the module-level globals ``env``, ``agent`` and ``writer``.

    Args:
        episodes: Number of episodes to run.
        eps_fixed: If True, epsilon stays 0 (always greedy); otherwise it decays linearly.
        eps_frames: Number of frames over which epsilon decays from 1 to ``min_eps``; afterwards
            epsilon is -1, which makes ``agent.act`` always greedy.
        min_eps: Lower bound of the decay.
        num_agents: Population size (length of the score and reward arrays).

    Returns:
        ``(scores, rewards, rates, rate, survival_times1, survival_times2)``:
        - ``scores``: mean per-agent cumulative reward for each episode.
        - ``rewards``: one list per agent with the rewards from the steps where it was observed.
        - ``rates``: censoring fraction at the last step of each episode.
        - ``rate``: censoring fraction at every step.
        - ``survival_times1`` / ``survival_times2``: per-step values returned by ``env.step``.
    """
    scores = []  # list containing scores from each episode
    rewards = [[] for _ in range(num_agents)]
    rate = []
    rates = []
    survival_times1 = []
    survival_times2 = []
    scores_window = deque(maxlen=100)  # last 100 scores
    survival_times1_window = deque(maxlen=20)
    survival_times2_window = deque(maxlen=20)
    frame = 0
    eps = 0 if eps_fixed else 1
    eps_start = 1
    state, observe = env.reset()
    score = np.zeros(num_agents)
    for i_episode in range(1, episodes + 1):

        while True:
            action, _ = agent.act(state, eps)
            next_state, reward, next_observe, Delta, survival_time1, survival_time2, done = env.step(action)
            agent.step(state, action, observe, reward, next_state, done, writer)

            # The reward only covers agents observed before this step, in index order.
            indices = np.where(observe == 1)[0]
            step_rewards = _as_reward_row(reward)[:len(indices)]
            for idx, value in zip(indices, step_rewards):
                rewards[idx].append(value)
            score[indices] += step_rewards

            survival_times1.append(survival_time1)
            survival_times2.append(survival_time2)
            survival_times1_window.append(survival_time1)
            survival_times2_window.append(survival_time2)
            censor_rate = 1 - np.sum(observe) / len(observe)
            rate.append(censor_rate)
            state = next_state
            observe = next_observe
            frame += 1

            if not eps_fixed:
                if frame < eps_frames:
                    eps = max(eps_start - (frame * (1 / eps_frames)), min_eps)
                else:
                    eps = -1

            if done:
                scores_window.append(np.mean(score))  # save most recent score
                scores.append(np.mean(score))  # save most recent score
                rates.append(censor_rate)
                writer.add_scalar("Average100", np.mean(scores_window), frame)
                print('\rEpisode {} \tAEV: {:.2f} \tARED: {:.2f} '.format(i_episode, float(np.sum(survival_times1_window)), float(np.sum(survival_times2_window))), end="")
                state, observe = env.reset()
                score = np.zeros(num_agents)
                break

    return scores, rewards, rates, rate, survival_times1, survival_times2

###############################################################################################################
# General training wrapper: faithfully reproduces the three main scripts' save dirs/seed lists
###############################################################################################################

def train_all_for_mode(mode, seeds, save_prefix):
    """Train one DQN per (seed, time point) for a reward ``mode`` and write the metrics to text files.

    Args:
        mode: Reward model passed to CustomEnv ('aalen', 'cox' or 'km').
        seeds: Iterable of integer seeds; every seed runs every time point.
        save_prefix: Results go to ``f"{save_prefix}_seed_{seed}"`` (created if missing), one file
            per metric and time point.

    Side effects:
        - Rebinds the module-level globals ``env``, ``writer`` and ``agent`` that ``run`` reads.
        - Writes TensorBoard logs and result files.
    """
    global env, writer, agent

    # env.step reports done when its counter reaches max_steps, giving 20 steps per episode.
    max_steps = 20 - 1
    episodes = 50
    num_people = 1000
    time_points = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 14.0]  # Different time points
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Selected device:", device)

    # KM uses positive softplus; others use negative softplus (consistent with the three original scripts)
    neg_softplus_flag = (mode != 'km')

    for seed in seeds:
        subfolder_path = f"{save_prefix}_seed_{seed}"

        for time_point in time_points:
            print(f"\nRunning with time_point={time_point} and seed={seed}")
            np.random.seed(seed)
            torch.manual_seed(seed)
            env = CustomEnv(max_steps=max_steps, num_agents=num_people, mode=mode, time_point=time_point)
            env.seed(seed)
            action_size = env.action_space.n
            state_size = env.observation_space.shape

            # NOTE(review): the log directory omits mode and time_point, so all runs for one seed
            # log into the same TensorBoard directory — confirm this is intended.
            writer = SummaryWriter(f"DQNt_LL_new_1_num_{num_people}_seed_{seed}/")
            try:
                agent = DQN_Agent(state_size=state_size,
                                  action_size=action_size,
                                  Network="DDQN",
                                  layer_size=64,
                                  n_step=1,
                                  BATCH_SIZE=32,
                                  BUFFER_SIZE=6400,
                                  LR=1e-3,
                                  TAU=1e-2,
                                  GAMMA=0.99,
                                  UPDATE_EVERY=1,
                                  device=device,
                                  seed=seed,
                                  neg_softplus=neg_softplus_flag)

                t0 = time.time()
                scores, _, rates, rate, survival_times1, survival_times2 = run(episodes=episodes, eps_fixed=False, eps_frames=100, min_eps=0.025, num_agents=num_people)
                t1 = time.time()
                print(f"Training time for num={num_people} with seed {seed}: {round((t1 - t0) / 60, 2)}min")
            finally:
                # Release the event file even if training raises.
                writer.close()

            os.makedirs(subfolder_path, exist_ok=True)
            outputs = (
                ('scores_results', scores),
                ('rate_results', rate),
                ('rates_results', rates),
                ('survival_times1_results', survival_times1),
                ('survival_times2_results', survival_times2),
            )
            for stem, values in outputs:
                with open(os.path.join(subfolder_path, f'{stem}_time_{time_point}.txt'), 'w') as file:
                    file.write(f'{[float(x) for x in values]}\n')


if __name__ == "__main__":
    # Script entry point: train the Aalen additive and CoxPH variants, each over seeds 1..10.
    # `device` is bound at module level because other code in this file may read that global name.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Selected device:", device)
    seed = list(range(1, 11))
    for mode_name, prefix in (("aalen", "results_aah"), ("cox", "results_cox")):
        train_all_for_mode(mode=mode_name, seeds=seed, save_prefix=prefix)
